6.2 Entscheidungsbäume visualisieren und trainieren#

Im letzten Kapitel haben wir gelernt, wie mit Scikit-Learn ein Entscheidungsbaum für binäre Klassifikationsaufgaben trainiert wird. In diesem Kapitel werden wir uns damit beschäftigen, den trainierten Entscheidungsbaum von Scikit-Learn visualisieren zu lassen. Darüber hinaus lernen wir, was das Gini-Impurity-Kriterion ist und welche weiteren Einstellmöglichkeiten es für Entscheidungsbäume in Scikit-Learn gibt.

Lernziele#

Lernziele

  • Sie können einen Entscheidungsbaum mit plot_tree visualisieren.

  • Sie wissen, was die Angaben samples und value bei der Visualisierung des Entscheidungsbaumes bedeuten.

  • Sie wissen, was das Gini-Impurity-Kriterium ist.

  • Sie kennen weitere Parameter für Entscheidungsbäume wie random_state= oder criterion=.

Entscheidungsbäume visualisieren#

Im letzten Kapitel haben wir den Entscheidungsbaum für das Autohaus mit Hilfe des Moduls Scikit-Learn trainiert. Scikit-Learn bietet in dem Untermodul sklearn.tree nicht nur Algorithmen für Entscheidungsbäume an, sondern auch ein dazu passendes Visualisierungswerkzeug. Die Funktion plot_tree zeichnet den Entscheidungsbaum. Um diese Funktion auszuprobieren, wird zunächst der Datensatz mit den Autodaten erneut geladen, das Modell Entscheidungsbaum gewählt und anschließend trainiert.

import pandas as pd 
from sklearn.tree import DecisionTreeClassifier

# Sammlung der Daten 
daten = pd.DataFrame({
    'Kilometerstand [km]': [32908, 20328, 13285, 17162, 27449, 13715, 32889,  3111, 15607, 18295],
    'Preis [EUR]': [15960, 20495, 17227, 17851, 5428, 22772, 13581, 16793, 23253, 11382],
    'verkauft': [False, True, False, True, False, True, False, True, True, False],
    },
    index=['Auto 1', 'Auto 2', 'Auto 3', 'Auto 4', 'Auto 5', 'Auto 6', 'Auto 7', 'Auto 8', 'Auto 9', 'Auto 10'])
daten.head(10)

# Auswahl des Modells: Entscheidungsbaum für Klassifikation
modell = DecisionTreeClassifier(random_state=0)

# Adaption der Daten
X = daten[['Kilometerstand [km]', 'Preis [EUR]']]
y = daten['verkauft']

# Training des Modells
modell.fit(X,y)
DecisionTreeClassifier(random_state=0)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

Nun können wir die Funktion plot_tree importieren und das trainierte Modell visualisieren lassen.

from sklearn.tree import plot_tree

plot_tree(modell)
[Text(0.4, 0.875, 'x[1] <= 16376.5\ngini = 0.5\nsamples = 10\nvalue = [5, 5]'),
 Text(0.2, 0.625, 'gini = 0.0\nsamples = 4\nvalue = [4, 0]'),
 Text(0.30000000000000004, 0.75, 'True  '),
 Text(0.6, 0.625, 'x[0] <= 13500.0\ngini = 0.278\nsamples = 6\nvalue = [1, 5]'),
 Text(0.5, 0.75, '  False'),
 Text(0.4, 0.375, 'x[0] <= 8198.0\ngini = 0.5\nsamples = 2\nvalue = [1, 1]'),
 Text(0.2, 0.125, 'gini = 0.0\nsamples = 1\nvalue = [0, 1]'),
 Text(0.6, 0.125, 'gini = 0.0\nsamples = 1\nvalue = [1, 0]'),
 Text(0.8, 0.375, 'gini = 0.0\nsamples = 4\nvalue = [0, 4]')]
../_images/6582383b668e61b0166cc6046e9c4c3c2de5d5ecd0c3d2a380201d02c15fb3bb.png

plot_tree produziert eine Textausgabe und ein Diagramm. Die Textausgabe kann unterdrückt werden, indem hinter den Funktionsaufruf plot_tree(modell) ein Semikolon ; gesetzt wird. Das Diagramm zeichnet wie erwartet die Baumstruktur vom Wurzelknoten über die Knoten und Zweige bis hin zu den Blättern. Die Entscheidungsfragen stehen in der erste Zeile der Knoten. Danach folgen weitere Angaben wie gini, samples und value. Um diese Angaben zu erklären, ergänzen wir zunächst weitere Angaben. Mit der Option feature_names= wird eine Liste mit den Eigenschaften ergänzt, die Option class_names= ergänzt die Klassenbezeichnugnen. So erhalten wir folgendes Diagramm:

plot_tree(modell, 
    feature_names=['Kilometerstand [km]', 'Preis [EUR]'],
    class_names=['nicht verkauft', 'verkauft']);
../_images/3c2a1b45fa247ffaf5bb7fddd1656bae1eb2ecac966f4b305921846c26a33927.png

Was gini bedeuten könnte, erschließt sich so immer noch nicht, aber die Angaben samples und values können so leichter von ihrer Bedeutung her eingeordnet werden. samples gibt die Anzahl der Datenobjekte an, die sich in diesem Knoten befinden. values listet auf, wie viele Datenobjekte die Zielgröße nicht verkauft (= False bzw. 0) haben und wie viele zu der Klasse verkauft (= True bzw. 1) gehören.

Weitere Details zu den Optionen der plot_tree-Funktion finden Sie in der Dokumentation Scikit-Learn → plot_tree.

Als nächstes widmen wir uns der Bedeutung von gini.

Was ist das Gini-Impurity-Kriterium?#

Das Gini-Impurity-Kriterium ist ein Maß für die Unreinheit eines Datensatzes. Beim Beispiel mit dem Autohaus sind im Wurzelknoten fünf Autos, die nicht verkauft wurden, und fünf verkaufte Autos. Bei zwei Klassen ist das die maximale Unreinheit, die auftreten kann. Der Anteil der verkauften Autos ist genau 50 %. Diesem prozentualen Anteil wird das Gini-Impurity-Kriterium von 0.5 zugeordnet. Es gibt zwei weitere Extremfälle. Entweder sind nur verkaufte Autos im Datensatz (100 % verkaufte Autos) oder gar keine verkaufte Autos (0 % verkaufte Autos). In beiden Fällen ist der Datensatz rein, das Gini-Impurity-Kriterium ist 0. In allen anderen Fällen liegt das Gini-Impurity-Kriterium zwischen 0 und 0.5. Die Formel zur Berechnung des genauen Wertes des Gini-Impurity-Kriteriums lautet

\[\text{GI} = 1 - p^2 - (1-p)^2,\]

wenn \(p\) der prozentuale Anteil der verkauften Autos ist (das gilt natürlich allgemein für binäre Klassifikationsaufgaben und nicht nur das Autohaus-Beispiel).

Die folgende Abbildung zeigt die konkreten Werte des Gini-Impurity-Kriteriums für den prozentualen Anteil an verkauften Autos.

from numpy import linspace

p = linspace(0,1)
gini = 1 - p**2 - (1-p)**2

import plotly.express as px

fig = px.line(x = p, y = gini,
        title='Gini-Impurity-Kriterium',
        labels={'x': 'prozentualer Anteil', 'y': 'Wert des Gini-Impurity-Kriteriums'})
fig.show()

Im Diagramm können wir direkt ablesen, dass bei einem nicht verkauften Auto und fünf verkauften Autos (\(p = 0.8\bar{3}\)) das Gini-Impurity-Kriterium den Wert \(0.27\bar{7} \approx 0.278\) hat.

Das Gini-Impurity-Kriterium ist sehr wichtig für das Training eines Entscheidungsbaumes. Der Algorithmus probiert im Hintergrund verschiedene Möglichkeiten durch, mit Hilfe der Entscheidungsfragen den Datensatz zu splitten. Zu jedem Split werden dann die zugehörigen Werte des Gini-Impurity-Kriteriums berechnet. Dann wählt der Algorithmus den Split aus, der die höchste Reinheit hat (also den niedrigsten Gini-Impurity-Wert). Gilt das für mehrere Splits, dann wird zufällig ein Split ausgewählt.

Neben dem Gini-Impurity-Kriterium gibt es noch weitere Bewertungsmaße, um einen Entscheidungsbaum zu trainieren. In Scikit-Learn sind die beiden Alternativen log_less und entropy für den Shannonschen Informationsgewinn verfügbar. Wir schauen uns im Folgenden an, wie diese ausgewählt werden können. Wer zuvor sich noch ein wenig mehr mit den Details von Entscheidungsbäumen beschäftigen möchte, kann sich die folgenden Videos ansehen.

Optionales Video “Entscheidungsbäume #2 - Der ID3-Algorithmus” von The Morpheus Tutorials
Optionales Video “Entscheidungsbäume #3 - Entropie und Informationsgewinn” von The Morpheus Tutorials
Optionales Video “ID3 Entscheidungsbaum” von 42 Entwickler

Entscheidungsbäume trainieren#

Der Entscheidungsbaum-Klassifikationsalgorithmus von Scikit-Learn bietet noch weitere Optionen an, wie die Hilfe verrät

help(DecisionTreeClassifier())

oder in der Dokumentation Scikit-Learn → DecisionTreeClassifier() nachgelesen werden kann.

Sowohl bei der Initalisierung des Entscheidungsbaumes können Parameter gesetzt werden, als auch beim Verwenden der verschiedenen Methoden. Tatsächlich haben wir bereits weiter oben aus didaktischen Gründen den Parameter random_state=0 bei der Initialisierung gesetzt, damit immer der gleiche Entscheidungsbaum entsteht. In einem echten Projekt würde dieser Parameter nie verwendet werden.

Probieren Sie andere Werte für den Start des Zufallszahlengenerators aus und testen Sie, was sich verändert, wenn Sie andere Kriterien für das Splitting verwenden.

modell = DecisionTreeClassifier(criterion='entropy', random_state=3)
modell.fit(X,y)

plot_tree(modell, 
    feature_names=['Kilometerstand [km]', 'Preis [EUR]'],
    class_names=['nicht verkauft', 'verkauft']);
../_images/ad63e4f568ac99d3427844458de5a80e2922ca3875460cfd6d0f393724b07932.png

Zusammenfassung und Ausblick#

In diesem Kapitel haben wir das Training von Entscheidungsbäumen mit Hilfe der Bibliothek Scikit-Learn vertieft. Im nächsten Kapitel widmen wir uns den Vor-, aber auch den Nachteilen von Entscheidungsbäumen.